# libraries
library(tidyverse)
library(cregg)
library(hrbrthemes)
library(estimatr)
library(ggrepel)
library(stargazer)
library(here)

# location of the replication materials; here() resolves it against the project root
rep_path <- "r/replication"

# read the tidy conjoint data and store task number and profile as factors
df <- mutate(
  read_rds(here(rep_path, "data", "tidy-cjt.rds")),
  task_number = as.factor(task_number),
  profile = as.factor(profile)
)


# lookup table: attribute (feature) name -> human-readable label used in tables and plots
cleanup <- tibble(
  feature = c(
    "victim_gender",
    "crime",
    "perp",
    "social_distance",
    "perp_gender"
  ),
  label = c(
    "Gender of victim",
    "Type of crime",
    "Type of perpetrator",
    "Closeness to perpetrator",
    "Gender of perpetrator"
  )
)


# put the listed levels first in each factor so that the leading level
# acts as the reference category when AMCEs are estimated
df <- mutate(
  df,
  crime = fct_relevel(crime, "Fight", "Theft", "Extortion", "Sexual assault"),
  perp = fct_relevel(perp, "Common person", "Gang member", "Police"),
  social_distance = fct_relevel(social_distance, "Not from your community")
)


# label geoms take their font from the geom defaults rather than from the
# theme's base_family, so set the family on both label geoms explicitly
local({
  label_font <- list(family = "IBM Plex Sans Condensed")
  update_geom_defaults("label", label_font)
  update_geom_defaults("label_repel", label_font)
})



# main result tables: Table A.3 & Table A.4 -----------------------------------------------------------


# Table A.3: AMCE estimates from the main model, exported as a LaTeX table.
form <- chosen_dummy ~ victim_gender + crime + perp + social_distance + perp_gender

amce(
  data = df, formula = form, id = ~r_id,
  feature_order = c("social_distance", "perp_gender", "victim_gender", "crime", "perp")
) %>%
  # attach display labels; `feature` is the only column shared with the dictionary
  left_join(cleanup, by = "feature") %>%
  select(Attribute = label, Level = level, Estimate = estimate, SE = std.error) %>%
  mutate(
    # Reference levels are the rows reported without a standard error. Flag
    # them from the missing SE rather than from `Estimate == 0`, because a
    # genuine estimate that rounds to 0.000 would otherwise be mislabelled
    # as a baseline.
    is_baseline = is.na(SE),
    Estimate = ifelse(is_baseline, "(baseline)", round(Estimate, 3)),
    SE = ifelse(is_baseline, " ", round(SE, 3)),
    # print the attribute name only once, on its baseline row
    Attribute = ifelse(is_baseline, Attribute, " ")
  ) %>%
  select(-is_baseline) %>%
  stargazer(
    type = "latex", title = "AMCE estimates from main model.",
    summary = FALSE, label = "tab:amce",
    # build the path with here(), like every other output in this script,
    # so the table lands in the project regardless of the working directory
    out = here(rep_path, "figures", "amce-table.tex")
  )
  
# Table A.4: marginal mean estimates from the main model, exported as a LaTeX table.
cj(
  data = df, formula = form, id = ~r_id, estimate = "mm",
  feature_order = c("social_distance", "perp_gender", "victim_gender", "crime", "perp")
) %>%
  # attach display labels; `feature` is the only column shared with the dictionary
  left_join(cleanup, by = "feature") %>%
  select(Attribute = label, Level = level, Estimate = estimate, SE = std.error) %>%
  mutate(
    Estimate = round(Estimate, 3),
    SE = round(SE, 3)
  ) %>%
  stargazer(
    type = "latex", title = "Marginal Mean estimates from main model.",
    summary = FALSE, label = "tab:mm",
    # build the path with here(), like every other output in this script,
    # so the table lands in the project regardless of the working directory
    out = here(rep_path, "figures", "mm-table.tex")
  )




# Visualize main results: Figure 4 --------------------------------------------------



# Estimate AMCEs and marginal means for one outcome and stack them for plotting.
#
# Arguments:
#   data        data frame (or tibble) handed to cregg::cj() and cregg::mm()
#   formula     model formula, outcome on the left and attributes on the right
#   dictionary  data frame with a `feature` column; its remaining columns
#               (e.g. `label`) are joined onto the estimates
#   id          one-sided formula passed as `id` to both estimators;
#               defaults to ~r_id, the value previously hard-coded here
#
# Returns the marginal-means rows followed by the AMCE rows, with
#   level      "perp_" stripped from the level names
#   baseline   0 for AMCE rows, .5 for all other rows
#   statistic  recoded to the long names used as facet titles
get_both <- function(data, formula, dictionary, id = ~r_id) {
  # coerce once so both estimators receive the same plain data frame
  data <- as.data.frame(data)

  # AMCEs (cj()'s default estimate) with display labels attached
  amces <- cj(data, formula, id = id) %>%
    left_join(dictionary, by = "feature")

  # marginal means with display labels attached
  mms <- mm(data, formula, id = id) %>%
    left_join(dictionary, by = "feature")

  rbind(mms, amces) %>%
    mutate(
      # get rid of perp_
      level = str_replace(level, "perp_", ""),
      # baseline is computed before `statistic` is recoded just below
      baseline = ifelse(statistic == "amce", 0, .5),
      statistic = case_when(
        statistic == "amce" ~ "Average Marginal Component Effect",
        statistic == "mm" ~ "Marginal Means"
      )
    )
}

# estimates for the choice outcome (chosen_dummy)
form <- chosen_dummy ~ victim_gender + crime + perp + social_distance + perp_gender
pDat_chosen <- get_both(data = df, formula = form, dictionary = cleanup)

# one row of panels per attribute, one column per statistic; the dashed line
# marks each row's `baseline` value and levels are named by repelled labels
# instead of axis text
ggplot(
  pDat_chosen,
  aes(x = level, y = estimate, ymin = lower, ymax = upper, label = level)
) +
  geom_pointrange() +
  geom_hline(aes(yintercept = baseline), data = pDat_chosen, linetype = 2) +
  geom_label_repel(nudge_x = .1) +
  facet_grid(
    rows = vars(label), cols = vars(statistic),
    scales = "free", switch = "y"
  ) +
  coord_flip() +
  scale_y_percent() +
  labs(x = NULL, y = NULL) +
  theme_light(base_family = "IBM Plex Sans Condensed") +
  theme(
    panel.grid.major = element_blank(),
    axis.text.y = element_blank(),
    strip.background = element_rect(fill = "black"),
    strip.text = element_text(color = "white", face = "bold")
  )

# write the most recent plot to disk
ggsave(
  here(rep_path, "figures", "cjt-results-chosen.pdf"),
  device = cairo_pdf, width = 8, height = 10
)



# repeat for punishment: Figure 5---------------------------------------------------

# estimates for the punishment outcome (punish_dummy)
form <- punish_dummy ~ victim_gender + crime + perp + social_distance + perp_gender
pDat_punish <- get_both(data = df, formula = form, dictionary = cleanup)


# stack both outcomes, tagging each set of rows with its legend label
punish_chosen <- rbind(
  pDat_chosen %>% mutate(outcome = "Subject to community justice"),
  pDat_punish %>% mutate(outcome = "Punished more harshly")
)

# marginal means only, both outcomes dodged side by side within each attribute
# panel; level labels are drawn once, anchored on the punishment estimates
ggplot(
  filter(punish_chosen, statistic == "Marginal Means"),
  aes(
    x = level, y = estimate, ymin = lower, ymax = upper,
    label = level, shape = outcome, color = outcome
  )
) +
  geom_pointrange(
    position = position_dodge(width = .2),
    alpha = .8, size = .8
  ) +
  geom_hline(yintercept = .5, linetype = 2) +
  geom_label_repel(
    data = punish_chosen %>%
      filter(
        outcome == "Punished more harshly",
        statistic == "Marginal Means"
      ),
    nudge_x = -.2, color = "black",
    segment.size = 0,
    family = "Fira Sans"
  ) +
  facet_wrap(vars(label), scales = "free") +
  coord_flip() +
  scale_y_percent() +
  scale_color_brewer(palette = "Dark2") +
  labs(
    x = NULL,
    y = "Pr(crime with attribute selected)\nMarginal means",
    color = "Crime should be:",
    shape = "Crime should be:"
  ) +
  theme_light(base_family = "Fira Sans") +
  theme(
    panel.grid.major = element_blank(),
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    legend.position = "top",
    strip.background = element_rect(fill = "black"),
    strip.text = element_text(color = "white", face = "bold")
  )


# write the most recent plot to disk
ggsave(
  here(rep_path, "figures", "cjt-results-both.pdf"),
  device = cairo_pdf, width = 8, height = 10
)



# diagnostics: Figure A.3, Figure A.4, Figure A.2, Table A.5 -------------------------------------------------------------


form <- chosen_dummy ~ victim_gender + crime + perp + social_distance + perp_gender

## marginal means estimated separately for each task number
task_prefs <- cj(
  as.data.frame(df), form,
  id = ~r_id, by = ~task_number, estimate = "mm"
)

# one panel per attribute, task numbers distinguished by colour and shape
task_prefs %>%
  left_join(cleanup) %>%
  mutate(level = str_replace(level, "perp_", "")) %>%
  ggplot(
    aes(
      x = level, y = estimate, ymin = lower, ymax = upper,
      label = level, shape = BY, color = BY
    )
  ) +
  geom_pointrange(position = position_dodge(width = .5), alpha = .8) +
  facet_grid(rows = vars(label), scales = "free", switch = "x") +
  coord_flip() +
  scale_y_percent() +
  scale_color_brewer(palette = "Dark2") +
  labs(
    x = NULL,
    y = "Marginal means estimate",
    shape = "Task number:",
    color = "Task number:"
  ) +
  theme_light(base_family = "IBM Plex Sans Condensed") +
  theme(
    panel.grid.major = element_blank(),
    legend.position = "top",
    strip.background = element_rect(fill = "black"),
    strip.text = element_text(color = "white", face = "bold")
  )


# write the most recent plot to disk
ggsave(
  here(rep_path, "figures", "diagnostics-tasks.pdf"),
  device = cairo_pdf, width = 6, height = 10
)


# randomization check: bar chart of the cj_freqs() estimate for every
# attribute level, one panel per attribute
cj_freqs(as.data.frame(df), form, id = ~r_id) %>%
  left_join(cleanup) %>%
  mutate(level = str_replace(level, "perp_", "")) %>%
  ggplot(aes(x = level, y = estimate, label = level, fill = label)) +
  geom_col(alpha = .9) +
  facet_grid(rows = vars(label), scales = "free") +
  coord_flip() +
  scale_fill_manual(values = MetBrewer::met.brewer(name = "Degas")) +
  labs(
    x = "Attribute level",
    y = "Counts"
  ) +
  theme_light(base_family = "IBM Plex Sans Condensed") +
  theme(legend.position = "none")

# write the most recent plot to disk
ggsave(
  here(rep_path, "figures", "randomization.pdf"),
  device = cairo_pdf, width = 6, height = 10
)


# left-right preference: marginal means estimated separately by profile
left_right <- cj(
  as.data.frame(df), form,
  id = ~r_id, by = ~profile, estimate = "mm"
)

# table format: cj_anova() result for the same formula, split by profile
left_right_anova <- cj_anova(
  as.data.frame(df), form,
  id = ~r_id, by = ~profile
)

# export the comparison as a LaTeX table; `summary = FALSE` is spelled out
# because `F` is an ordinary variable that can be reassigned
stargazer(
  left_right_anova,
  type = "latex", summary = FALSE,
  label = "tab:left-right",
  title = "Diagnostic: do respondents choose one profile at higher rates than another.",
  style = "ajps",
  out = here(rep_path, "figures", "left-right-table.tex")
)



# plot marginal means by profile position: profile "A" is relabelled "Top"
# and "B" "Bottom"; the dashed line sits at .5
left_right %>%
  mutate(
    BY = case_when(
      BY == "A" ~ "Top",
      BY == "B" ~ "Bottom"
    )
  ) %>%
  left_join(cleanup) %>%
  mutate(level = str_replace(level, "perp_", "")) %>%
  ggplot(
    aes(
      x = level, y = estimate, ymin = lower, ymax = upper,
      label = level, shape = BY, color = BY
    )
  ) +
  geom_pointrange(position = position_dodge(width = .5)) +
  geom_hline(yintercept = .5, linetype = 2) +
  facet_grid(rows = vars(label), scales = "free", switch = "x") +
  coord_flip() +
  scale_y_percent() +
  scale_color_brewer(palette = "Dark2") +
  labs(
    x = NULL,
    y = "Marginal means estimate",
    shape = "Profile position:",
    color = "Profile position:"
  ) +
  theme_light(base_family = "IBM Plex Sans Condensed") +
  theme(
    panel.grid.major = element_blank(),
    legend.position = "top",
    strip.background = element_rect(fill = "black"),
    strip.text = element_text(color = "white", face = "bold")
  )

# write the most recent plot to disk
ggsave(
  here(rep_path, "figures", "diagnostics-left-right.pdf"),
  device = cairo_pdf, width = 6, height = 10
)
